
import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class FedQlearning_gen_adv:
    """Federated Q-learning with a reference-advantage value decomposition.

    ``num_agents`` agents repeatedly run episodes on ``mdp`` with the greedy
    policy of the shared table ``global_Q``.  Each agent buffers, per
    (h, s, a), its visit count and the sums of the next-state value V, the
    reference value V_ref, the advantage V - V_ref, and their squares.  As
    soon as any agent meets the event-triggered synchronisation condition
    (:meth:`check_sync_triggered`) the server merges the buffers
    (:meth:`aggregate_data`) and, for every (h, s, a) whose current stage has
    ended (:meth:`check_stage_triggered`), lowers ``global_Q[h, s, a]`` to the
    minimum of two bonus-augmented estimates and its previous value.

    NOTE(review): ``mdp`` is only accessed through ``H``, ``S``, ``A``,
    ``reset()``, ``step(action) -> (next_state, reward)``, ``best_gen()`` and
    ``value_gen(policy)``; its exact contract is defined elsewhere.
    """

    def __init__(self, mdp, total_episodes, num_agents, using_adv_min = 100):
        """
        Args:
            mdp: environment the agents interact with (see class docstring).
            total_episodes: episodes run by *each* agent, so
                ``total_episodes * num_agents`` episodes are run in total.
            num_agents: number of federated agents (M).
            using_adv_min: once the total visit count ``N[h, s, :].sum()``
                reaches this value, ``V_ref_func[h, s]`` is frozen to the
                current ``V_func[h, s]``.
        """
        self.mdp = mdp
        self.total_episodes = total_episodes
        self.num_agents = num_agents
        self.using_adv_min = using_adv_min
        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        self.V_func = np.zeros((H + 1, S), dtype=np.float32)      # estimated value function
        self.V_ref_func = np.zeros((H + 1, S), dtype=np.float32)  # reference value function
        self.trigger_times = 0             # number of communication rounds so far
        self.comm_episode_collection = []  # episode index of every communication round
        self.regret_cum = 0                # cumulative regret (reset at the start of learn())

        # Server-side sums.  The ``*_all`` sums are never reset and the
        # ``*_stage`` sums cover stages whose length grows by a factor
        # (1 + 1/H), so they are kept in float64: in float32 large sums lose
        # precision and the E[X^2] - E[X]^2 variance estimates built from
        # them cancel catastrophically.
        self.V_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.V2_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.Vref_sum_all = np.zeros((H, S, A), dtype=np.float64)
        self.Vref2_sum_all = np.zeros((H, S, A), dtype=np.float64)
        self.Vadv_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.Vadv2_sum_stage = np.zeros((H, S, A), dtype=np.float64)

        self.V_ref_trigger = np.zeros((H, S), dtype=np.int32)  # 1 once V_ref[h, s] is frozen

        self.count_variance = np.zeros((H, S, A), dtype=np.float32)  # not used inside this class

        self.N = np.zeros((H, S, A), dtype=np.int32)                # visits over all finished stages
        self.n_previous_st = np.zeros((H, S, A), dtype=np.float32)  # length of the last finished stage
        self.n_current_st = np.zeros((H, S, A), dtype=np.float32)   # visits in the current stage

        # Initial Q: global_Q[h, s, a] = H - h for every (s, a).
        self.global_Q = np.empty((H, S, A), dtype=np.float32)
        self.global_Q[:] = (H - np.arange(H))[:, np.newaxis, np.newaxis]

        # Per-agent buffers; cleared after every aggregation.
        agent_shape = (num_agents, H, S, A)
        self.agent_N = np.zeros(agent_shape, dtype=np.int32)
        self.agent_V_sum = np.zeros(agent_shape, dtype=np.float32)
        self.agent_V2_sum = np.zeros(agent_shape, dtype=np.float32)
        self.agent_Vref_sum = np.zeros(agent_shape, dtype=np.float32)
        self.agent_Vref2_sum = np.zeros(agent_shape, dtype=np.float32)
        self.agent_Vadv_sum = np.zeros(agent_shape, dtype=np.float32)
        self.agent_Vadv2_sum = np.zeros(agent_shape, dtype=np.float32)

        self.regret = []  # cumulative regret after every agent-episode
        self.cost = []    # communication rounds so far, recorded after every episode index

    def run_episode(self, agent_id, actions_policy=None):
        """Run one episode for ``agent_id`` and add its statistics to the agent's buffers.

        Args:
            agent_id: index of the agent whose local buffers are updated.
            actions_policy: optional one-hot policy of shape (H, S, A) as
                returned by :meth:`choose_action`.  When omitted it is
                recomputed from ``global_Q``.  Passing it avoids redundant work,
                because ``global_Q`` only changes in :meth:`aggregate_data`.

        Returns:
            ``(rewards, event_triggered, state_init)``:
            ``rewards[h, s, a]`` holds the reward received at each visited
            (h, s, a) (zero elsewhere); ``event_triggered`` is True if any step
            met the synchronisation condition; ``state_init`` is the state
            returned by ``mdp.reset()``.
        """
        if actions_policy is None:
            actions_policy = self.choose_action()
        event_triggered = False
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            v_next = self.V_func[step + 1, next_state]
            v_ref_next = self.V_ref_func[step + 1, next_state]
            v_adv_next = v_next - v_ref_next
            idx = (agent_id, step, state, action)
            self.agent_N[idx] += 1
            self.agent_V_sum[idx] += v_next
            self.agent_V2_sum[idx] += v_next ** 2
            self.agent_Vref_sum[idx] += v_ref_next
            self.agent_Vref2_sum[idx] += v_ref_next ** 2
            self.agent_Vadv_sum[idx] += v_adv_next
            self.agent_Vadv2_sum[idx] += v_adv_next ** 2

            rewards[step, state, action] = reward

            # Every step is checked (no early exit) so the whole episode is still played.
            if self.check_sync_triggered(agent_id, step, state, action):
                event_triggered = True
            state = next_state
        return rewards, event_triggered, state_init

    def choose_action(self):
        """Return the greedy policy of ``global_Q`` as a one-hot (H, S, A) array.

        Ties are broken towards the lowest action index, as ``np.argmax`` does.
        """
        actions = np.zeros([self.mdp.H, self.mdp.S, self.mdp.A])
        best_actions = np.argmax(self.global_Q, axis=2)
        np.put_along_axis(actions, best_actions[..., np.newaxis], 1, axis=2)
        return actions

    @staticmethod
    def _greedy_mask(policy):
        """Boolean mask over (h, s, a) that is True exactly at argmax_a policy[h, s, a]."""
        mask = np.zeros(policy.shape, dtype=bool)
        np.put_along_axis(mask, np.argmax(policy, axis=2)[..., np.newaxis], True, axis=2)
        return mask

    def check_sync_triggered(self, agent_id, step, state, action):
        """Return True if the agent's local count for (step, state, action) hit the sync threshold.

        The threshold is 1 while (step, state, action) has no finished stage.
        Otherwise it depends on the last stage length n_prev and the visits
        n_cur already aggregated in the current stage:
        floor(n_prev / (M * H)) while n_cur > (1 - 1/H) * n_prev, and
        ceil((n_prev - n_cur) / M) otherwise.
        """
        previous_state_visit = self.n_previous_st[step, state, action]
        current_state_visit = self.n_current_st[step, state, action]
        threshold = 1
        if previous_state_visit > 0:
            if current_state_visit > (1 - 1 / self.mdp.H) * previous_state_visit:
                threshold = round(np.floor(previous_state_visit / self.num_agents / self.mdp.H))
            else:
                threshold = round(np.ceil((previous_state_visit - current_state_visit) / self.num_agents))

        return self.agent_N[agent_id, step, state, action] >= threshold

    def check_stage_triggered(self, step, state, action):
        """Return True if the current stage of (step, state, action) has ended.

        The first stage ends after M * H aggregated visits.  Each later stage
        ends after at least (1 + 1/H) times the previous stage's visit count.
        """
        previous_state_visit = self.n_previous_st[step, state, action]
        current_state_visit = self.n_current_st[step, state, action]

        return current_state_visit >= (self.num_agents * self.mdp.H) * (previous_state_visit == 0) + (
            1 + 1 / self.mdp.H) * previous_state_visit

    def aggregate_data(self, policy_k, rewards):  # after a round
        """Merge all agents' buffers into the server statistics (one communication round).

        Only the greedy action of ``policy_k`` at each (h, s) is merged, and
        only if some agent visited it during the round.  For every merged
        (h, s, a) whose stage has ended, ``global_Q`` is updated and the stage
        sums are reset.  All agent buffers are cleared afterwards.

        Args:
            policy_k: one-hot policy used during the round, shape (H, S, A).
            rewards: reward table indexed as ``rewards[h, s, a]``.
        """
        visits = self.agent_N.sum(axis=0)
        updated = self._greedy_mask(policy_k) & (visits > 0)

        # Explicit read-add-write so the integer counts are cast into the float32 array.
        self.n_current_st[updated] = self.n_current_st[updated] + visits[updated]
        for server_sum, agent_sum in (
            (self.V_sum_stage, self.agent_V_sum),
            (self.V2_sum_stage, self.agent_V2_sum),
            (self.Vref_sum_all, self.agent_Vref_sum),
            (self.Vref2_sum_all, self.agent_Vref2_sum),
            (self.Vadv_sum_stage, self.agent_Vadv_sum),
            (self.Vadv2_sum_stage, self.agent_Vadv2_sum),
        ):
            server_sum[updated] += agent_sum.sum(axis=0, dtype=np.float64)[updated]

        # Each (h, s, a) update touches only its own entries, so the order is irrelevant.
        for h, s, a in zip(*np.nonzero(updated)):
            if self.check_stage_triggered(h, s, a):
                self._end_stage(h, s, a, rewards[h, s, a])

        self._reset_agent_buffers()

    def _end_stage(self, h, s, a, reward):
        """Close the current stage of (h, s, a): lower global_Q and reset the stage sums."""
        H = self.mdp.H
        n_stage = self.n_current_st[h, s, a]
        self.N[h, s, a] += n_stage
        n_all = self.N[h, s, a]

        # Estimate 1: stage average of V with a sqrt(2 (H-h-1)^2 / n) bonus.
        q_stage = (reward + self.V_sum_stage[h, s, a] / n_stage
                   + np.sqrt(2 * (H - h - 1) * (H - h - 1) / n_stage))

        # Estimate 2: all-time mean of V_ref plus the stage mean of V - V_ref,
        # with empirical-variance bonuses and an H * n^(-3/4) correction.
        ref_mean = self.Vref_sum_all[h, s, a] / n_all
        adv_mean = self.Vadv_sum_stage[h, s, a] / n_stage
        # Clamp at 0: E[X^2] - E[X]^2 can come out slightly negative from rounding.
        sigma2_vref = max(self.Vref2_sum_all[h, s, a] / n_all - ref_mean ** 2, 0)
        sigma2_vadv = max(self.Vadv2_sum_stage[h, s, a] / n_stage - adv_mean ** 2, 0)
        q_ref_adv = (reward + ref_mean + adv_mean
                     + 2 * np.sqrt(sigma2_vref / n_all)
                     + 2 * np.sqrt(sigma2_vadv / n_stage)
                     + H * (1 / n_stage ** (3 / 4) + 1 / n_all ** (3 / 4)))

        # Q never increases.
        self.global_Q[h, s, a] = min(q_stage, q_ref_adv, self.global_Q[h, s, a])

        self.n_previous_st[h, s, a] = n_stage
        self.n_current_st[h, s, a] = 0.0
        self.V_sum_stage[h, s, a] = 0.0
        self.V2_sum_stage[h, s, a] = 0.0
        self.Vadv_sum_stage[h, s, a] = 0.0
        self.Vadv2_sum_stage[h, s, a] = 0.0

    def _reset_agent_buffers(self):
        """Zero every per-agent buffer after a communication round."""
        for buf in (self.agent_N, self.agent_V_sum, self.agent_V2_sum,
                    self.agent_Vref_sum, self.agent_Vref2_sum,
                    self.agent_Vadv_sum, self.agent_Vadv2_sum):
            buf.fill(0)

    def update_reference(self, h, s):
        """Freeze V_ref[h, s] to V[h, s] the first time N[h, s, :].sum() reaches using_adv_min."""
        if self.V_ref_trigger[h, s] == 1:
            return
        if self.N[h, s, :].sum() >= self.using_adv_min:
            self.V_ref_trigger[h, s] = 1
            self.V_ref_func[h, s] = self.V_func[h, s]

    def _update_all_references(self):
        """Apply :meth:`update_reference` to every (h, s) with h < H."""
        for h in range(self.mdp.H):
            for s in range(self.mdp.S):
                self.update_reference(h, s)

    def _refresh_value_function(self):
        """Set V[h, s] = max_a global_Q[h, s, a] for h < H (V[H] stays 0)."""
        self.V_func[:self.mdp.H] = self.global_Q.max(axis=2)

    def learn(self):
        """Run ``total_episodes`` episode indices, each with one episode per agent.

        After every episode index, if any agent triggered synchronisation, the
        server aggregates, recomputes the greedy policy, refreshes V and
        possibly freezes reference values.

        Returns:
            ``(best_Q, global_Q)``: ``best_Q`` as returned by ``mdp.best_gen()``
            and the learned Q table.  ``self.regret`` (cumulative regret after
            every agent-episode) and ``self.cost`` (rounds so far after every
            episode index) are filled as a side effect.
        """
        self.regret_cum = 0
        best_value, best_policy, best_Q = self.mdp.best_gen()
        # Reward table: each greedy (h, s, a) keeps the first nonzero reward observed.
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        self._refresh_value_function()
        self.V_ref_func[:self.mdp.H] = self.V_func[:self.mdp.H]
        self._update_all_references()

        actions_policy = self.choose_action()
        greedy_mask = self._greedy_mask(actions_policy)
        for episode in range(self.total_episodes):
            event_triggered = False
            # NOTE(review): actions_policy only changes after a round, so this
            # could be cached if value_gen is side-effect free -- confirm first.
            value = self.mdp.value_gen(actions_policy)
            for agent_id in range(self.num_agents):
                agent_reward, agent_event_triggered, state_init = self.run_episode(
                    agent_id, actions_policy)
                self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
                self.regret.append(self.regret_cum)

                unknown = greedy_mask & (rewards == 0)
                rewards[unknown] = agent_reward[unknown]

                if agent_event_triggered:
                    event_triggered = True

            # Communication round: aggregate and update the policy.
            if event_triggered:
                self.trigger_times += 1
                self.comm_episode_collection.append(episode)
                self.aggregate_data(actions_policy, rewards)
                actions_policy = self.choose_action()
                greedy_mask = self._greedy_mask(actions_policy)
                self._refresh_value_function()
                self._update_all_references()
            self.cost.append(self.trigger_times)
        return best_Q, self.global_Q